{
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3",
   "language": "python"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.13",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kaggle": {
   "accelerator": "nvidiaTeslaT4",
   "dataSources": [],
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook",
   "isGpuEnabled": true
  }
 },
 "nbformat_minor": 4,
 "nbformat": 4,
 "cells": [
  {
   "cell_type": "code",
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "print(sys.version_info)\n",
    "for module in mpl, np, pd, sklearn, torch:\n",
    "    print(module.__name__, module.__version__)\n",
    "    \n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(device)\n",
    "\n",
    "seed = 42\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:12.527334Z",
     "iopub.execute_input": "2024-05-03T04:47:12.527802Z",
     "iopub.status.idle": "2024-05-03T04:47:17.362854Z",
     "shell.execute_reply.started": "2024-05-03T04:47:12.527757Z",
     "shell.execute_reply": "2024-05-03T04:47:17.361864Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.588067Z",
     "start_time": "2025-03-10T14:46:19.458312Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=12, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.10.0\n",
      "numpy 1.26.4\n",
      "pandas 2.2.3\n",
      "sklearn 1.6.1\n",
      "torch 2.6.0+cu118\n",
      "cuda:0\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 准备数据"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "from tensorflow import keras\n",
    "#用karas有的数据集imdb，电影分类,分电影是积极的，还是消极的\n",
    "imdb = keras.datasets.imdb\n",
    "#载入数据使用下面两个参数\n",
    "vocab_size = 10000  #词典大小，仅保留训练数据中前10000个最经常出现的单词，低频单词被舍弃\n",
    "index_from = 3  #0,1,2,3空出来做别的事\n",
    "#前一万个词出现词频最高的会保留下来进行处理，后面的作为特殊字符处理，\n",
    "# 小于3的id都是特殊字符，下面代码有写\n",
    "# 需要注意的一点是取出来的词表还是从1开始的，需要做处理\n",
    "(train_data, train_labels), (test_data, test_labels) = imdb.load_data(\n",
    "    num_words = vocab_size, index_from = index_from)"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:17.364755Z",
     "iopub.execute_input": "2024-05-03T04:47:17.365247Z",
     "iopub.status.idle": "2024-05-03T04:47:34.391998Z",
     "shell.execute_reply.started": "2024-05-03T04:47:17.365218Z",
     "shell.execute_reply": "2024-05-03T04:47:34.391210Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.699599Z",
     "start_time": "2025-03-10T14:46:28.589072Z"
    }
   },
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'tensorflow'",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mModuleNotFoundError\u001B[0m                       Traceback (most recent call last)",
      "Cell \u001B[1;32mIn[2], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21;01mtensorflow\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m keras\n\u001B[0;32m      2\u001B[0m \u001B[38;5;66;03m#用karas有的数据集imdb，电影分类,分电影是积极的，还是消极的\u001B[39;00m\n\u001B[0;32m      3\u001B[0m imdb \u001B[38;5;241m=\u001B[39m keras\u001B[38;5;241m.\u001B[39mdatasets\u001B[38;5;241m.\u001B[39mimdb\n",
      "\u001B[1;31mModuleNotFoundError\u001B[0m: No module named 'tensorflow'"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "source": [
    "\n",
    "print(\"train\", len(train_data), train_labels.shape)  #25000个样本，每个样本是一段话，每个单词用一个数字表示\n",
    "print(\"test\", len(test_data), test_labels.shape)"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:34.393563Z",
     "iopub.execute_input": "2024-05-03T04:47:34.394093Z",
     "iopub.status.idle": "2024-05-03T04:47:34.399413Z",
     "shell.execute_reply.started": "2024-05-03T04:47:34.394068Z",
     "shell.execute_reply": "2024-05-03T04:47:34.398409Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.701608Z",
     "start_time": "2025-03-10T14:46:28.700607Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "#载入词表，看下词表长度，词表就像英语字典\n",
    "word_index = imdb.get_word_index()\n",
    "print(len(word_index))\n",
    "print(type(word_index))\n",
    "#词表虽然有8万多，但是我们只载入了最高频的1万词！！！！"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:34.402571Z",
     "iopub.execute_input": "2024-05-03T04:47:34.403158Z",
     "iopub.status.idle": "2024-05-03T04:47:34.535124Z",
     "shell.execute_reply.started": "2024-05-03T04:47:34.403131Z",
     "shell.execute_reply": "2024-05-03T04:47:34.534127Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.702612Z",
     "start_time": "2025-03-10T14:46:28.702612Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 构造 word2idx 和 idx2word"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "word2idx = {word: idx + 3 for word, idx in word_index.items()}\n",
    "word2idx.update({\n",
    "    \"[PAD]\": 0,     # 填充 token\n",
    "    \"[BOS]\": 1,     # begin of sentence\n",
    "    \"[UNK]\": 2,     # 未知 token\n",
    "    \"[EOS]\": 3,     # end of sentence\n",
    "})\n",
    "\n",
    "idx2word = {idx: word for word, idx in word2idx.items()}"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:34.536374Z",
     "iopub.execute_input": "2024-05-03T04:47:34.536752Z",
     "iopub.status.idle": "2024-05-03T04:47:34.579550Z",
     "shell.execute_reply.started": "2024-05-03T04:47:34.536716Z",
     "shell.execute_reply": "2024-05-03T04:47:34.578607Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.703609Z",
     "start_time": "2025-03-10T14:46:28.703609Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "# 选择 max_length\n",
    "length_collect = {}\n",
    "for text in train_data:\n",
    "    length = len(text)\n",
    "    length_collect[length] = length_collect.get(length, 0) + 1\n",
    "    \n",
    "MAX_LENGTH = 500\n",
    "plt.bar(length_collect.keys(), length_collect.values())\n",
    "plt.axvline(MAX_LENGTH, label=\"max length\", c=\"gray\", ls=\":\")\n",
    "plt.legend()\n",
    "plt.show()"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:34.580697Z",
     "iopub.execute_input": "2024-05-03T04:47:34.581037Z",
     "iopub.status.idle": "2024-05-03T04:47:36.765733Z",
     "shell.execute_reply.started": "2024-05-03T04:47:34.581009Z",
     "shell.execute_reply": "2024-05-03T04:47:36.764837Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.704609Z",
     "start_time": "2025-03-10T14:46:28.704609Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Tokenizer"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "class Tokenizer:\n",
    "    def __init__(self, word2idx, idx2word, max_length=500, pad_idx=0, bos_idx=1, eos_idx=3, unk_idx=2):\n",
    "        self.word2idx = word2idx\n",
    "        self.idx2word = idx2word\n",
    "        self.max_length = max_length\n",
    "        self.pad_idx = pad_idx\n",
    "        self.bos_idx = bos_idx\n",
    "        self.eos_idx = eos_idx\n",
    "        self.unk_idx = unk_idx\n",
    "    \n",
    "    def encode(self, text_list, padding_first=False):\n",
    "        \"\"\"如果padding_first == True，则padding加载前面，否则加载后面\"\"\"\n",
    "        max_length = min(self.max_length, 2 + max([len(text) for text in text_list]))\n",
    "        indices_list = []\n",
    "        for text in text_list:\n",
    "            indices = [self.bos_idx] + [self.word2idx.get(word, self.unk_idx) for word in text[:max_length-2]] + [self.eos_idx]\n",
    "            if padding_first:\n",
    "                indices = [self.pad_idx] * (max_length - len(indices)) + indices\n",
    "            else:\n",
    "                indices = indices + [self.pad_idx] * (max_length - len(indices))\n",
    "            indices_list.append(indices)\n",
    "        return torch.tensor(indices_list)\n",
    "    \n",
    "    \n",
    "    def decode(self, indices_list, remove_bos=True, remove_eos=True, remove_pad=True, split=False):\n",
    "        text_list = []\n",
    "        for indices in indices_list:\n",
    "            text = []\n",
    "            for index in indices:\n",
    "                word = self.idx2word.get(index, \"[UNK]\")\n",
    "                if remove_bos and word == \"[BOS]\":\n",
    "                    continue\n",
    "                if remove_eos and word == \"[EOS]\":\n",
    "                    break\n",
    "                if remove_pad and word == \"[PAD]\":\n",
    "                    break\n",
    "                text.append(word)\n",
    "            text_list.append(\" \".join(text) if not split else text)\n",
    "        return text_list\n",
    "    \n",
    "\n",
    "tokenizer = Tokenizer(word2idx=word2idx, idx2word=idx2word)\n",
    "raw_text = [\"hello world\".split(), \"tokenize text datas with batch\".split(), \"this is a test\".split()]\n",
    "indices = tokenizer.encode(raw_text, padding_first=True)\n",
    "decode_text = tokenizer.decode(indices.tolist(), remove_bos=False, remove_eos=False, remove_pad=False)\n",
    "print(\"raw text\")\n",
    "for raw in raw_text:\n",
    "    print(raw)\n",
    "print(\"indices\")\n",
    "for index in indices:\n",
    "    print(index)\n",
    "print(\"decode text\")\n",
    "for decode in decode_text:\n",
    "    print(decode)"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:36.767290Z",
     "iopub.execute_input": "2024-05-03T04:47:36.767595Z",
     "iopub.status.idle": "2024-05-03T04:47:36.811355Z",
     "shell.execute_reply.started": "2024-05-03T04:47:36.767569Z",
     "shell.execute_reply": "2024-05-03T04:47:36.810435Z"
    },
    "trusted": true
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "# 看看训练集的数据\n",
    "\n",
    "tokenizer.decode(train_data[:2], remove_bos=False, remove_eos=False, remove_pad=False)"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:36.812516Z",
     "iopub.execute_input": "2024-05-03T04:47:36.812799Z",
     "iopub.status.idle": "2024-05-03T04:47:36.819435Z",
     "shell.execute_reply.started": "2024-05-03T04:47:36.812774Z",
     "shell.execute_reply": "2024-05-03T04:47:36.818339Z"
    },
    "trusted": true
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 数据集与 DataLoader"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class IMDBDataset(Dataset):\n",
    "    def __init__(self, data, labels, remain_length=True):\n",
    "        if remain_length:\n",
    "            self.data = tokenizer.decode(data, remove_bos=False, remove_eos=False, remove_pad=False)\n",
    "        else:\n",
    "            # 缩减一下数据\n",
    "            self.data = tokenizer.decode(data)\n",
    "        self.labels = labels\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        text = self.data[index]\n",
    "        label = self.labels[index]\n",
    "        return text, label\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "    \n",
    "    \n",
    "def collate_fct(batch):\n",
    "    text_list = [item[0].split() for item in batch]\n",
    "    label_list = [item[1] for item in batch]\n",
    "    # 这里使用 padding first\n",
    "    text_list = tokenizer.encode(text_list, padding_first=True).to(dtype=torch.int)\n",
    "    return text_list, torch.tensor(label_list).reshape(-1, 1).to(dtype=torch.float)\n",
    "\n",
    "\n",
    "# 用RNN，缩短序列长度\n",
    "train_ds = IMDBDataset(train_data, train_labels, remain_length=False)\n",
    "test_ds = IMDBDataset(test_data, test_labels, remain_length=False)"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:36.820601Z",
     "iopub.execute_input": "2024-05-03T04:47:36.820915Z",
     "iopub.status.idle": "2024-05-03T04:47:42.935725Z",
     "shell.execute_reply.started": "2024-05-03T04:47:36.820890Z",
     "shell.execute_reply": "2024-05-03T04:47:42.934882Z"
    },
    "trusted": true
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "batch_size = 128\n",
    "train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fct)\n",
    "test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fct)"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:42.940057Z",
     "iopub.execute_input": "2024-05-03T04:47:42.940354Z",
     "iopub.status.idle": "2024-05-03T04:47:42.945758Z",
     "shell.execute_reply.started": "2024-05-03T04:47:42.940327Z",
     "shell.execute_reply": "2024-05-03T04:47:42.944620Z"
    },
    "trusted": true
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 定义模型"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "class LSTM(nn.Module):\n",
    "    def __init__(self, embedding_dim=16, hidden_dim=64, vocab_size=vocab_size, num_layers=1, bidirectional=False):\n",
    "        super(LSTM, self).__init__()\n",
    "        self.embeding = nn.Embedding(vocab_size, embedding_dim)\n",
    "        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True, bidirectional=bidirectional)\n",
    "        self.layer = nn.Linear(hidden_dim * (2 if bidirectional else 1), hidden_dim)\n",
    "        self.fc = nn.Linear(hidden_dim, 1)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        # [bs, seq length]\n",
    "        x = self.embeding(x)\n",
    "        # [bs, seq length, embedding_dim] -> shape [bs, embedding_dim, seq length]\n",
    "        seq_output, (hidden, cell) = self.lstm(x)\n",
    "        # [bs, seq length, hidden_dim], [*, bs, hidden_dim]\n",
    "        x = seq_output[:, -1, :]\n",
    "        # 取最后一个时间步的输出 (这也是为什么要设置padding_first=True的原因)\n",
    "        x = self.layer(x)\n",
    "        x = self.fc(x)\n",
    "        return x\n",
    "    \n",
    "sample_inputs = torch.randint(0, vocab_size, (2, 128))\n",
    "    \n",
    "print(\"{:=^80}\".format(\" 一层单向 LSTM \"))       \n",
    "for key, value in LSTM().named_parameters():\n",
    "    print(f\"{key:^40}paramerters num: {np.prod(value.shape)}\")\n",
    "\n",
    "    \n",
    "print(\"{:=^80}\".format(\" 一层双向 LSTM \"))       \n",
    "for key, value in LSTM(bidirectional=True).named_parameters():\n",
    "    print(f\"{key:^40}paramerters num: {np.prod(value.shape)}\")\n",
    "\n",
    "    \n",
    "print(\"{:=^80}\".format(\" 两层单向 LSTM \"))       \n",
    "for key, value in LSTM(num_layers=2).named_parameters():\n",
    "    print(f\"{key:^40}paramerters num: {np.prod(value.shape)}\")\n"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:42.947030Z",
     "iopub.execute_input": "2024-05-03T04:47:42.947353Z",
     "iopub.status.idle": "2024-05-03T04:47:42.981775Z",
     "shell.execute_reply.started": "2024-05-03T04:47:42.947327Z",
     "shell.execute_reply": "2024-05-03T04:47:42.980874Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.709607Z",
     "start_time": "2025-03-10T14:46:28.708609Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "4 * 16 * 64 #lstm.weight_ih_l0"
   ],
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:42.982885Z",
     "iopub.execute_input": "2024-05-03T04:47:42.983224Z",
     "iopub.status.idle": "2024-05-03T04:47:42.989768Z",
     "shell.execute_reply.started": "2024-05-03T04:47:42.983190Z",
     "shell.execute_reply": "2024-05-03T04:47:42.988745Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.710608Z",
     "start_time": "2025-03-10T14:46:28.710608Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "4* 64*64 #lstm.weight_hh_l0"
   ],
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:42.991003Z",
     "iopub.execute_input": "2024-05-03T04:47:42.991308Z",
     "iopub.status.idle": "2024-05-03T04:47:42.998461Z",
     "shell.execute_reply.started": "2024-05-03T04:47:42.991280Z",
     "shell.execute_reply": "2024-05-03T04:47:42.997549Z"
    },
    "trusted": true
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 训练"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluating(model, dataloader, loss_fct):\n",
    "    loss_list = []\n",
    "    pred_list = []\n",
    "    label_list = []\n",
    "    for datas, labels in dataloader:\n",
    "        datas = datas.to(device)\n",
    "        labels = labels.to(device)\n",
    "        # 前向计算\n",
    "        logits = model(datas)\n",
    "        loss = loss_fct(logits, labels)         # 验证集损失\n",
    "        loss_list.append(loss.item())\n",
    "        # 二分类\n",
    "        preds = logits > 0\n",
    "        pred_list.extend(preds.cpu().numpy().tolist())\n",
    "        label_list.extend(labels.cpu().numpy().tolist())\n",
    "        \n",
    "    acc = accuracy_score(label_list, pred_list)\n",
    "    return np.mean(loss_list), acc\n"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:42.999576Z",
     "iopub.execute_input": "2024-05-03T04:47:42.999917Z",
     "iopub.status.idle": "2024-05-03T04:47:43.091789Z",
     "shell.execute_reply.started": "2024-05-03T04:47:42.999882Z",
     "shell.execute_reply": "2024-05-03T04:47:43.091056Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.712611Z",
     "start_time": "2025-03-10T14:46:28.712611Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "### TensorBoard 可视化\n",
    "\n",
    "\n",
    "训练过程中可以使用如下命令启动tensorboard服务。\n",
    "\n",
    "```shell\n",
    "tensorboard \\\n",
    "    --logdir=runs \\     # log 存放路径\n",
    "    --host 0.0.0.0 \\    # ip\n",
    "    --port 8848         # 端口\n",
    "```"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "class TensorBoardCallback:\n",
    "    def __init__(self, log_dir, flush_secs=10):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            log_dir (str): dir to write log.\n",
    "            flush_secs (int, optional): write to dsk each flush_secs seconds. Defaults to 10.\n",
    "        \"\"\"\n",
    "        self.writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)\n",
    "\n",
    "    def draw_model(self, model, input_shape):\n",
    "        self.writer.add_graph(model, input_to_model=torch.randn(input_shape))\n",
    "        \n",
    "    def add_loss_scalars(self, step, loss, val_loss):\n",
    "        self.writer.add_scalars(\n",
    "            main_tag=\"training/loss\", \n",
    "            tag_scalar_dict={\"loss\": loss, \"val_loss\": val_loss},\n",
    "            global_step=step,\n",
    "            )\n",
    "        \n",
    "    def add_acc_scalars(self, step, acc, val_acc):\n",
    "        self.writer.add_scalars(\n",
    "            main_tag=\"training/accuracy\",\n",
    "            tag_scalar_dict={\"accuracy\": acc, \"val_accuracy\": val_acc},\n",
    "            global_step=step,\n",
    "        )\n",
    "        \n",
    "    def add_lr_scalars(self, step, learning_rate):\n",
    "        self.writer.add_scalars(\n",
    "            main_tag=\"training/learning_rate\",\n",
    "            tag_scalar_dict={\"learning_rate\": learning_rate},\n",
    "            global_step=step,\n",
    "            \n",
    "        )\n",
    "    \n",
    "    def __call__(self, step, **kwargs):\n",
    "        # add loss\n",
    "        loss = kwargs.pop(\"loss\", None)\n",
    "        val_loss = kwargs.pop(\"val_loss\", None)\n",
    "        if loss is not None and val_loss is not None:\n",
    "            self.add_loss_scalars(step, loss, val_loss)\n",
    "        # add acc\n",
    "        acc = kwargs.pop(\"acc\", None)\n",
    "        val_acc = kwargs.pop(\"val_acc\", None)\n",
    "        if acc is not None and val_acc is not None:\n",
    "            self.add_acc_scalars(step, acc, val_acc)\n",
    "        # add lr\n",
    "        learning_rate = kwargs.pop(\"lr\", None)\n",
    "        if learning_rate is not None:\n",
    "            self.add_lr_scalars(step, learning_rate)\n"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:43.092913Z",
     "iopub.execute_input": "2024-05-03T04:47:43.093218Z",
     "iopub.status.idle": "2024-05-03T04:47:43.118190Z",
     "shell.execute_reply.started": "2024-05-03T04:47:43.093169Z",
     "shell.execute_reply": "2024-05-03T04:47:43.117453Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.713611Z",
     "start_time": "2025-03-10T14:46:28.713611Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Save Best\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "class SaveCheckpointsCallback:\n",
    "    def __init__(self, save_dir, save_step=5000, save_best_only=True):\n",
    "        \"\"\"\n",
    "        Save checkpoints each save_epoch epoch. \n",
    "        We save checkpoint by epoch in this implementation.\n",
    "        Usually, training scripts with pytorch evaluating model and save checkpoint by step.\n",
    "\n",
    "        Args:\n",
    "            save_dir (str): dir to save checkpoint\n",
    "            save_epoch (int, optional): the frequency to save checkpoint. Defaults to 1.\n",
    "            save_best_only (bool, optional): If True, only save the best model or save each model at every epoch.\n",
    "        \"\"\"\n",
    "        self.save_dir = save_dir\n",
    "        self.save_step = save_step\n",
    "        self.save_best_only = save_best_only\n",
    "        self.best_metrics = -1\n",
    "        \n",
    "        # mkdir\n",
    "        if not os.path.exists(self.save_dir):\n",
    "            os.mkdir(self.save_dir)\n",
    "        \n",
    "    def __call__(self, step, state_dict, metric=None):\n",
    "        if step % self.save_step > 0:\n",
    "            return\n",
    "        \n",
    "        if self.save_best_only:\n",
    "            assert metric is not None\n",
    "            if metric >= self.best_metrics:\n",
    "                # save checkpoints\n",
    "                torch.save(state_dict, os.path.join(self.save_dir, \"best.ckpt\"))\n",
    "                # update best metrics\n",
    "                self.best_metrics = metric\n",
    "        else:\n",
    "            torch.save(state_dict, os.path.join(self.save_dir, f\"{step}.ckpt\"))\n",
    "\n"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:43.119248Z",
     "iopub.execute_input": "2024-05-03T04:47:43.119557Z",
     "iopub.status.idle": "2024-05-03T04:47:43.129100Z",
     "shell.execute_reply.started": "2024-05-03T04:47:43.119532Z",
     "shell.execute_reply": "2024-05-03T04:47:43.128085Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.714609Z",
     "start_time": "2025-03-10T14:46:28.714609Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Early Stop"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "class EarlyStopCallback:\n",
    "    def __init__(self, patience=5, min_delta=0.01):\n",
    "        \"\"\"\n",
    "\n",
    "        Args:\n",
    "            patience (int, optional): Number of epochs with no improvement after which training will be stopped.. Defaults to 5.\n",
    "            min_delta (float, optional): Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute \n",
    "                change of less than min_delta, will count as no improvement. Defaults to 0.01.\n",
    "        \"\"\"\n",
    "        self.patience = patience\n",
    "        self.min_delta = min_delta\n",
    "        self.best_metric = -1\n",
    "        self.counter = 0\n",
    "        \n",
    "    def __call__(self, metric):\n",
    "        if metric >= self.best_metric + self.min_delta:\n",
    "            # update best metric\n",
    "            self.best_metric = metric\n",
    "            # reset counter \n",
    "            self.counter = 0\n",
    "        else: \n",
    "            self.counter += 1\n",
    "            \n",
    "    @property\n",
    "    def early_stop(self):\n",
    "        return self.counter >= self.patience\n"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:43.130506Z",
     "iopub.execute_input": "2024-05-03T04:47:43.130817Z",
     "iopub.status.idle": "2024-05-03T04:47:43.140541Z",
     "shell.execute_reply.started": "2024-05-03T04:47:43.130781Z",
     "shell.execute_reply": "2024-05-03T04:47:43.139593Z"
    },
    "trusted": true
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "# 训练\n",
    "def training(\n",
    "    model, \n",
    "    train_loader, \n",
    "    val_loader, \n",
    "    epoch, \n",
    "    loss_fct, \n",
    "    optimizer, \n",
    "    tensorboard_callback=None,\n",
    "    save_ckpt_callback=None,\n",
    "    early_stop_callback=None,\n",
    "    eval_step=500,\n",
    "    ):\n",
    "    record_dict = {\n",
    "        \"train\": [],\n",
    "        \"val\": []\n",
    "    }\n",
    "    \n",
    "    global_step = 0\n",
    "    model.train()\n",
    "    with tqdm(total=epoch * len(train_loader)) as pbar:\n",
    "        for epoch_id in range(epoch):\n",
    "            # training\n",
    "            for datas, labels in train_loader:\n",
    "                datas = datas.to(device)\n",
    "                labels = labels.to(device)\n",
    "                # 梯度清空\n",
    "                optimizer.zero_grad()\n",
    "                # 模型前向计算\n",
    "                logits = model(datas)\n",
    "                # 计算损失\n",
    "                loss = loss_fct(logits, labels)\n",
    "                # 梯度回传\n",
    "                loss.backward()\n",
    "                # 调整优化器，包括学习率的变动等\n",
    "                optimizer.step()\n",
    "                preds = logits > 0\n",
    "            \n",
    "                acc = accuracy_score(labels.cpu().numpy(), preds.cpu().numpy())    \n",
    "                loss = loss.cpu().item()\n",
    "                # record\n",
    "                \n",
    "                record_dict[\"train\"].append({\n",
    "                    \"loss\": loss, \"acc\": acc, \"step\": global_step\n",
    "                })\n",
    "                \n",
    "                # evaluating\n",
    "                if global_step % eval_step == 0:\n",
    "                    model.eval()\n",
    "                    val_loss, val_acc = evaluating(model, val_loader, loss_fct)\n",
    "                    record_dict[\"val\"].append({\n",
    "                        \"loss\": val_loss, \"acc\": val_acc, \"step\": global_step\n",
    "                    })\n",
    "                    model.train()\n",
    "                    \n",
    "                    # 1. 使用 tensorboard 可视化\n",
    "                    if tensorboard_callback is not None:\n",
    "                        tensorboard_callback(\n",
    "                            global_step, \n",
    "                            loss=loss, val_loss=val_loss,\n",
    "                            acc=acc, val_acc=val_acc,\n",
    "                            lr=optimizer.param_groups[0][\"lr\"],\n",
    "                            )\n",
    "                \n",
    "                    # 2. 保存模型权重 save model checkpoint\n",
    "                    if save_ckpt_callback is not None:\n",
    "                        save_ckpt_callback(global_step, model.state_dict(), metric=val_acc)\n",
    "\n",
    "                    # 3. 早停 Early Stop\n",
    "                    if early_stop_callback is not None:\n",
    "                        early_stop_callback(val_acc)\n",
    "                        if early_stop_callback.early_stop:\n",
    "                            print(f\"Early stop at epoch {epoch_id} / global_step {global_step}\")\n",
    "                            return record_dict\n",
    "                    \n",
    "                # udate step\n",
    "                global_step += 1\n",
    "                pbar.update(1)\n",
    "                pbar.set_postfix({\"epoch\": epoch_id})\n",
    "        \n",
    "    return record_dict\n",
    "        \n",
    "\n",
    "epoch = 20\n",
    "\n",
    "model = LSTM()\n",
    "\n",
    "# 1. 定义损失函数 采用交叉熵损失 (但是二分类)\n",
    "loss_fct = F.binary_cross_entropy_with_logits\n",
    "# 2. 定义优化器 采用 adam\n",
    "# Optimizers specified in the torch.optim package\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# 1. tensorboard 可视化\n",
    "if not os.path.exists(\"runs\"):\n",
    "    os.mkdir(\"runs\")\n",
    "tensorboard_callback = TensorBoardCallback(\"runs/imdb-lstm\")\n",
    "# tensorboard_callback.draw_model(model, [1, MAX_LENGTH])\n",
    "# 2. save best\n",
    "if not os.path.exists(\"checkpoints\"):\n",
    "    os.makedirs(\"checkpoints\")\n",
    "save_ckpt_callback = SaveCheckpointsCallback(\"checkpoints/imdb-lstm\", save_step=len(train_dl), save_best_only=True)\n",
    "# 3. early stop\n",
    "early_stop_callback = EarlyStopCallback(patience=10)\n",
    "\n",
    "\n",
    "# 如果有可用的多个GPU\n",
    "# if torch.cuda.device_count() > 1:\n",
    "#     print(\"使用多个GPU进行训练...\")\n",
    "#     model = nn.DataParallel(model)\n",
    "\n"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:43.142057Z",
     "iopub.execute_input": "2024-05-03T04:47:43.142394Z",
     "iopub.status.idle": "2024-05-03T04:47:44.062187Z",
     "shell.execute_reply.started": "2024-05-03T04:47:43.142363Z",
     "shell.execute_reply": "2024-05-03T04:47:44.061381Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.717609Z",
     "start_time": "2025-03-10T14:46:28.717609Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "model = model.to(device)\n",
    "start=time.time()\n",
    "record = training(\n",
    "    model, \n",
    "    train_dl, \n",
    "    test_dl, \n",
    "    epoch, \n",
    "    loss_fct, \n",
    "    optimizer, \n",
    "    tensorboard_callback=tensorboard_callback,\n",
    "    save_ckpt_callback=save_ckpt_callback,\n",
    "    early_stop_callback=early_stop_callback,\n",
    "    eval_step=len(train_dl)\n",
    "    )\n",
    "end=time.time()\n",
    "print(f'use time{end-start}')"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:47:44.063276Z",
     "iopub.execute_input": "2024-05-03T04:47:44.063556Z",
     "iopub.status.idle": "2024-05-03T04:53:29.669289Z",
     "shell.execute_reply.started": "2024-05-03T04:47:44.063531Z",
     "shell.execute_reply": "2024-05-03T04:53:29.668367Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.718608Z",
     "start_time": "2025-03-10T14:46:28.718608Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "#画线要注意的是损失是不一定在零到1之间的\n",
    "def plot_learning_curves(record_dict, sample_step=500):\n",
    "    # build DataFrame\n",
    "    train_df = pd.DataFrame(record_dict[\"train\"]).set_index(\"step\").iloc[::sample_step]\n",
    "    val_df = pd.DataFrame(record_dict[\"val\"]).set_index(\"step\")\n",
    "\n",
    "    # plot\n",
    "    fig_num = len(train_df.columns)\n",
    "    fig, axs = plt.subplots(1, fig_num, figsize=(5 * fig_num, 5))\n",
    "    for idx, item in enumerate(train_df.columns):    \n",
    "        axs[idx].plot(train_df.index, train_df[item], label=f\"train_{item}\")\n",
    "        axs[idx].plot(val_df.index, val_df[item], label=f\"val_{item}\")\n",
    "        axs[idx].grid()\n",
    "        axs[idx].legend()\n",
    "        # axs[idx].set_xticks(range(0, train_df.index[-1], 5000))\n",
    "        # axs[idx].set_xticklabels(map(lambda x: f\"{int(x/1000)}k\", range(0, train_df.index[-1], 5000)))\n",
    "        axs[idx].set_xlabel(\"step\")\n",
    "    \n",
    "    plt.show()\n",
    "\n",
    "plot_learning_curves(record, sample_step=10)  #横坐标是 steps"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:53:29.670641Z",
     "iopub.execute_input": "2024-05-03T04:53:29.671094Z",
     "iopub.status.idle": "2024-05-03T04:53:30.081623Z",
     "shell.execute_reply.started": "2024-05-03T04:53:29.671059Z",
     "shell.execute_reply": "2024-05-03T04:53:30.080600Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.719609Z",
     "start_time": "2025-03-10T14:46:28.719609Z"
    }
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 评估"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "# dataload for evaluating\n",
    "\n",
    "# load checkpoints\n",
    "model.load_state_dict(torch.load(\"checkpoints/imdb-lstm/best.ckpt\", map_location=\"cpu\"))\n",
    "\n",
    "model.eval()\n",
    "loss, acc = evaluating(model, test_dl, loss_fct)\n",
    "print(f\"loss:     {loss:.4f}\\naccuracy: {acc:.4f}\")"
   ],
   "metadata": {
    "execution": {
     "iopub.status.busy": "2024-05-03T04:53:30.082914Z",
     "iopub.execute_input": "2024-05-03T04:53:30.083787Z",
     "iopub.status.idle": "2024-05-03T04:53:37.545678Z",
     "shell.execute_reply.started": "2024-05-03T04:53:30.083748Z",
     "shell.execute_reply": "2024-05-03T04:53:37.544622Z"
    },
    "trusted": true,
    "ExecuteTime": {
     "end_time": "2025-03-10T14:46:28.720632Z",
     "start_time": "2025-03-10T14:46:28.720632Z"
    }
   },
   "outputs": [],
   "execution_count": null
  }
 ]
}
